# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Produce the dataset for mnist. """

import os

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2

from mindspore.dataset.vision import Inter
from mindspore.common import dtype as mstype
from mindspore.communication.management import init, get_rank, get_group_size
from mindvision.classification.utils.build_trans import build_trans

# Maps each MNIST class index (0-9) to its digit label string.
index2label = {digit: str(digit) for digit in range(10)}


def mnist_dataset(dataset_path,
                  states,
                  batch_size=32,
                  shuffle=False,
                  repeat_num=1,
                  resize_height=32,
                  resize_width=32,
                  config=None,
                  distribute=False,
                  enable_cache=False,
                  cache_session_id=None,
                  num_parallel_workers=1,
                  target="GPU"):
    """
    Create the MNIST dataset pipeline for training, evaluation or inference.

    Args:
        dataset_path (str): Directory passed to ``ds.MnistDataset``.
        states (str): Pipeline mode. ``"train"`` always disables the cache;
            ``"infer"`` additionally returns the ``index2label`` mapping.
        batch_size (int): Batch size; the last incomplete batch is dropped.
        shuffle (bool): Forwarded to ``ds.MnistDataset``. A buffered
            ``shuffle`` op is still applied afterwards regardless of this flag.
        repeat_num (int): Number of times the batched dataset is repeated.
        resize_height (int): Target height used by the default transforms.
        resize_width (int): Target width used by the default transforms.
        config: Optional config object. When truthy, the image transforms are
            built from ``config.train_data.transforms`` via ``build_trans``
            instead of the default resize/rescale/HWC2CHW chain.
        distribute (bool): For non-Ascend targets, initialize communication
            and shard the dataset by rank.
        enable_cache (bool): Attach a ``ds.DatasetCache`` to the label
            type-cast stage. Forced off when ``states == "train"``.
        cache_session_id: Cache session id (converted with ``int``); required
            when the cache is used.
        num_parallel_workers (int): Worker count for loading and map ops.
        target (str): Device target. ``"Ascend"`` takes the rank layout from
            ``_get_rank_info``; any other value uses ``distribute``. Defaults
            to ``"GPU"``, the value that was previously hard-coded.

    Returns:
        The shuffled, batched and repeated dataset, or the tuple
        ``(dataset, index2label)`` when ``states == "infer"``.

    Raises:
        ValueError: If the cache is enabled without a ``cache_session_id``.
    """
    if target == "Ascend":
        device_num, rank_id = _get_rank_info()
    elif distribute:
        init()
        rank_id = get_rank()
        device_num = get_group_size()
    else:
        # Single device: rank_id is defined so it is never left unbound.
        device_num, rank_id = 1, 0

    # Sharding arguments are only passed when more than one device is used.
    shard_kwargs = {}
    if device_num != 1:
        shard_kwargs = {"num_shards": device_num, "shard_id": rank_id}
    mnist_ds = ds.MnistDataset(dataset_path, num_parallel_workers=num_parallel_workers,
                               shuffle=shuffle, **shard_kwargs)

    if not config:
        rescale = 1.0 / 255.0
        shift = 0.0
        # Combined with the first rescale, this maps a pixel to (x / 255 - 0.1307) / 0.3081.
        rescale_nml = 1 / 0.3081
        shift_nml = -1 * 0.1307 / 0.3081

        # define map operations
        trans = [
            C.Resize((resize_height, resize_width), interpolation=Inter.LINEAR),  # bilinear mode
            C.Rescale(rescale, shift),
            C.Rescale(rescale_nml, shift_nml),
            C.HWC2CHW(),
        ]
    else:
        # NOTE(review): train transforms are used for every `states` value — confirm intended.
        trans = build_trans(config.train_data.transforms)

    type_cast_op = C2.TypeCast(mstype.int32)

    # apply map operations on images
    mnist_ds = mnist_ds.map(operations=trans, input_columns="image", num_parallel_workers=num_parallel_workers)

    # only enable cache for eval
    if states == "train":
        enable_cache = False

    label_map_kwargs = {}
    if enable_cache:
        if not cache_session_id:
            raise ValueError("A cache session_id must be provided to use cache.")
        label_map_kwargs["cache"] = ds.DatasetCache(session_id=int(cache_session_id), size=0)

    mnist_ds = mnist_ds.map(operations=type_cast_op,
                            input_columns="label",
                            num_parallel_workers=num_parallel_workers,
                            **label_map_kwargs)

    # apply DatasetOps
    # NOTE(review): set_seed is process-global, and this buffered shuffle runs even when
    # `shuffle=False` — confirm both are intended.
    ds.config.set_seed(58)
    mnist_ds = mnist_ds.shuffle(buffer_size=10000)  # 10000 as in LeNet train script
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)

    # apply dataset repeat operation
    mnist_ds = mnist_ds.repeat(repeat_num)

    if states == "infer":
        return mnist_ds, index2label
    return mnist_ds


def _get_rank_info():
    """
    get rank size and rank id
    """
    rank_size = int(os.environ.get("RANK_SIZE", 1))

    if rank_size > 1:
        rank_size = get_group_size()
        rank_id = get_rank()
    else:
        rank_size = 1
        rank_id = 0

    return rank_size, rank_id
